package com.example.autumn.jdbc;

import com.example.autumn.exeception.DataAccessException;

import javax.sql.DataSource;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

/**
 * @author liuzhiyong
 * @date 2023/11/3
 * description:
 */
public class JdbcTemplate {

    final DataSource dataSource;


    public JdbcTemplate(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * 执行数据库操作
     *
     * @param action 在连接内要执行的操作
     * @return {@link T } 结果
     * @author liuzhiyong
     * @date 2023/11/3
     */
    public <T> T execute(ConnectionCallback<T> action) {
        try (Connection newConn = dataSource.getConnection()) {
            final boolean autoCommit = newConn.getAutoCommit();
            if (!autoCommit) {
                newConn.setAutoCommit(true);
            }
            T result = action.doInConnection(newConn);
            if (!autoCommit) {
                newConn.setAutoCommit(false);
            }
            return result;
        } catch (SQLException e) {
            throw new DataAccessException(e);
        }
    }

    /**
     * 指定数据库操作
     *
     * @param psc 创建预编译对象的函数
     * @param action 利用预编译对象要执行的操作 函数
     * @return {@link T } 结果
     * @author liuzhiyong
     * @date 2023/11/3
     */
    public <T> T execute(PreparedStatementCreator psc, PreparedStatementCallback<T> action) {
        return execute(con -> {
            try (PreparedStatement ps = psc.createPreparedStatement(con)) {
                return action.doInPreparedStatement(ps);
            }
        });
    }

    /**
     * 执行更新操作
     *
     * @param sql sql
     * @param args 参数
     * @return {@link int }
     * @author liuzhiyong
     * @date 2023/11/3
     */
    public int update(String sql, Object... args) throws DataAccessException {
        return execute(preparedStatementCreator(sql, args), PreparedStatement::executeUpdate);
    }

    /**
     * 更新一行数据, 并返回数数字主键
     *
     * @param sql sql
     * @param args 参数
     * @return {@link Number } 主键
     * @author liuzhiyong
     * @date 2023/11/7
     */
    public Number updateAndReturnGeneratedKey(String sql, Object... args) throws DataAccessException {
        return updateAndReturnGeneratedKey(sql,NumberRowMapper.instance, args);
    }

    /**
     * 更新一行数据并返回字符串主键
     *
     * @param sql sql
     * @param args 参数
     * @return {@link String } 结果
     * @author liuzhiyong
     * @date 2023/11/7
     */
    public String updateAndReturnGeneratedStringKey(String sql, Object... args) throws DataAccessException {
        return updateAndReturnGeneratedKey(sql, StringRowMapper.instance, args);
    }

    /**
     * 更新一行数据, 并返回主键
     *
     * @param sql sql
     * @param rowMapper 主键映射函数
     * @param args 参数
     * @return {@link T } 主键
     * @author liuzhiyong
     * @date 2023/11/7
     */
    private <T> T updateAndReturnGeneratedKey(String sql,RowMapper<T> rowMapper, Object... args) throws DataAccessException {
        return execute(con -> {
            PreparedStatement ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
            bindArgs(ps, args);
            return ps;
        }, ps -> {
            int n = ps.executeUpdate();
            if (n == 0) {
                throw new DataAccessException("0 rows inserted.");
            }
            if (n > 1) {
                throw new DataAccessException("Multiple rows inserted.");
            }
            try (ResultSet keys = ps.getGeneratedKeys()) {
                if (keys.next()) {
                    return rowMapper.mapRow(keys, keys.getRow());
                }
            }
            throw new DataAccessException("Should not reach here.");
        });
    }


    /**
     * 查询结果为数字
     *
     * @param sql sql
     * @param args 参数
     * @return {@link Number } 结果
     * @author liuzhiyong
     * @date 2023/11/7
     */
    public Number queryForNumber(String sql, Object... args) throws DataAccessException {
        return queryForObject(sql, NumberRowMapper.instance, args);
    }

    /**
     * 查询单条数据
     *
     * @param sql sql字符串
     * @param rowMapper 结果映射函数
     * @param args 参数
     * @return {@link T } 结果
     * @author liuzhiyong
     * @date 2023/11/6
     */
    public <T> T queryForObject(String  sql, RowMapper<T> rowMapper, Object... args) {
        return execute(preparedStatementCreator(sql, args), ps -> {
            T t = null;
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    if (t == null) {
                        t = rowMapper.mapRow(rs, rs.getRow());
                    } else {
                        // 查询一条数据, 多条的话会抛出异常
                        throw new DataAccessException("Multiple rows found.");
                    }
                }
            }
            if (t == null) {
                throw new DataAccessException("Empty result set.");
            }
            return t;
        });
    }

    /**
     * 查询单条数据
     *
     * @param sql sql
     * @param clazz 返回类型的class
     * @param args 参数
     * @return {@link T } 结果
     * @author liuzhiyong
     * @date 2023/11/7
     */
    @SuppressWarnings("unchecked")
    public <T> T queryForObject(String sql, Class<T> clazz, Object... args) throws DataAccessException {
        if (clazz == String.class) {
            return (T) queryForObject(sql, StringRowMapper.instance, args);
        }
        if (clazz == Boolean.class || clazz == boolean.class) {
            return (T) queryForObject(sql, BooleanRowMapper.instance, args);
        }
        // 如果类是Number的子类型(包装类) 或者 是基本类型 (boolean、byte、short、int、long、float、double和char)
        if (Number.class.isAssignableFrom(clazz) || clazz.isPrimitive()) {
            return (T) queryForObject(sql, NumberRowMapper.instance, args);
        }
        return queryForObject(sql, new BeanRowMapper<>(clazz), args);
    }

    /**
     * 查询多条数据
     *
     * @param sql sql
     * @param clazz 结果类型
     * @param args 参数
     * @return {@link List<T> } 结果
     * @author liuzhiyong
     * @date 2023/11/7
     */
    public <T> List<T> queryForList(String sql, Class<T> clazz, Object... args) throws DataAccessException{
        return queryForList(sql, new BeanRowMapper<>(clazz), args);
    }

    /**
     * 查询 将结果映射成执行类型集合
     *
     * @param sql sql
     * @param rowMapper 记过映射函数
     * @param args 参数
     * @return {@link List<T> } 结果
     * @author liuzhiyong
     * @date 2023/11/7
     */
    public <T> List<T> queryForList(String sql, RowMapper<T> rowMapper, Object... args) throws DataAccessException {
        return execute(preparedStatementCreator(sql, args), ps -> {
            List<T> list = new ArrayList<>();
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    list.add(rowMapper.mapRow(rs, rs.getRow()));
                }
            }
            return list;
        });
    }

    /**
     * 生成创建预编译对象的函数
     *
     * @param sql sql
     * @param args 参数
     * @return {@link PreparedStatementCreator } 创建预编译对象的函数
     * @author liuzhiyong
     * @date 2023/11/3
     */
    private PreparedStatementCreator preparedStatementCreator(String sql, Object... args) {
        return con -> {
            PreparedStatement ps = con.prepareStatement(sql);
            bindArgs(ps, args);
            return ps;
        };
    }

    /**
     * 绑定预编译对象和参数
     * 也就是给预编译对象填充SQL参数
     *
     * @param ps SQL预编译对象
     * @param args 参数
     * @author liuzhiyong
     * @date 2023/11/3
     */
    private void bindArgs(PreparedStatement ps, Object... args) throws SQLException {
        for (int i = 0; i < args.length; i++) {
            ps.setObject(i + 1, args[i]);
        }
    }

}

/**
 * Row mapper that returns the first column of a row as a {@code String}.
 */
class StringRowMapper implements RowMapper<String> {

    /** Shared instance; final so it cannot be reassigned. */
    static final StringRowMapper instance = new StringRowMapper();

    /**
     * Reads column 1 of the current row with {@link ResultSet#getString(int)}.
     *
     * @param rs     result set positioned on the row to map
     * @param rowNum row number supplied by the caller (not used)
     * @return the value of the first column
     * @throws SQLException if reading the column fails
     */
    @Override
    public String mapRow(ResultSet rs, int rowNum) throws SQLException {
        return rs.getString(1);
    }

}

/**
 * Row mapper that returns the first column of a row as a {@code Boolean}.
 */
class BooleanRowMapper implements RowMapper<Boolean> {

    /** Shared instance; final so it cannot be reassigned. */
    static final BooleanRowMapper instance = new BooleanRowMapper();

    /**
     * Reads column 1 of the current row with {@link ResultSet#getBoolean(int)}.
     * Per the JDBC contract of that method, an SQL {@code NULL} is reported as
     * {@code false}, so this mapper never returns {@code null}.
     *
     * @param rs     result set positioned on the row to map
     * @param rowNum row number supplied by the caller (not used)
     * @return the value of the first column
     * @throws SQLException if reading the column fails
     */
    @Override
    public Boolean mapRow(ResultSet rs, int rowNum) throws SQLException {
        return rs.getBoolean(1);
    }

}

/**
 * Row mapper that returns the first column of a row as a {@code Number}.
 */
class NumberRowMapper implements RowMapper<Number> {

    /** Shared instance; final so it cannot be reassigned. */
    static final NumberRowMapper instance = new NumberRowMapper();

    /**
     * Reads column 1 of the current row with {@link ResultSet#getObject(int)} and casts
     * the result to {@code Number}. The concrete subtype is whatever the driver returns;
     * a non-numeric column value makes the cast fail with a {@code ClassCastException}.
     *
     * @param rs     result set positioned on the row to map
     * @param rowNum row number supplied by the caller (not used)
     * @return the value of the first column, or {@code null} if {@code getObject} returns {@code null}
     * @throws SQLException if reading the column fails
     */
    @Override
    public Number mapRow(ResultSet rs, int rowNum) throws SQLException {
        return (Number) rs.getObject(1);
    }
}
